#-*-#-*-#-*-#-*-#-*-#-*-#-*-#-*-#-*-#-*-#
#   VISUALIZE CATE RESULTS (HEATMAPS)   #
#-*-#-*-#-*-#-*-#-*-#-*-#-*-#-*-#-*-#-*-#


#########
# SETUP #
#########

# Start from an empty global environment.
# NOTE(review): this wipes every object of whoever sources the script; kept
# because the replication workflow may rely on it.
rm(list = ls())

# Load libraries
library("grf")
library("tidyverse")
library("dplyr")
library("haven") # for reading in .dta files
library("ggplot2")

# Set paths
# All paths are derived from `root`, so this is the only line that has to be
# adapted on another machine. Every path built below is absolute; the working
# directory is still switched to `root` for the rest of the session.
root <- "W:/Ramon/Child_Health_Shocks/Replication"
setwd(root)
data_folder <- file.path(root, "data")
output_cate <- file.path(root, "output", "HTE")

# Set font for graphs
# NOTE(review): par() only configures base-graphics devices. ggplot2 draws via
# grid and takes its font from theme(text = element_text(family = ...)), so
# this call does not change the font of the heatmaps created below.
par(family = "Times")

# List datasets, num features, and outcomes
# These three values are the components of the input file name
# "cf_output_<dataset>_<num features>_<outcome>.dta".
all_data_options <- list("no_repeated_obs", "no_repeated_obs_pairs")
all_outcomes <- list(
  "change_income_postpre",
  "post_working",
  "post_mentalhealth"
)
selected_num_features <- "numfeat_14"

# SELECT your datasets and outcomes here if you do not want to run all options.
# Single-bracket indexing keeps `selected_datasets` a list, so the plotting
# loop can iterate over it even when only one dataset is chosen.
selected_datasets <- all_data_options[1]
selected_outcomes <- all_outcomes


################
# CREATE PLOTS #
################

# For every selected dataset and outcome: read the causal-forest output,
# average the predicted treatment effect within each education level x
# pre-shock income quartile cell, and save the result as a heatmap PDF.
#
# NOTE(review): the output file name depends only on `outcome`, so when more
# than one dataset is selected, later datasets overwrite the plots of earlier
# ones.
for (data_choice in selected_datasets) {
  for (outcome in selected_outcomes) {

    # Direction of the colour scale: deeper colours denote a more severe
    # treatment effect, so outcomes associated with negative treatment effects
    # use direction -1. Unknown outcomes stop here instead of silently reusing
    # the settings left over from the previous iteration.
    selected_direction <- switch(
      outcome,
      change_income_postpre = ,
      post_working = -1,
      post_mentalhealth = 1,
      stop(
        sprintf("No plot settings defined for outcome '%s'", outcome),
        call. = FALSE
      )
    )

    # Outcomes reported in percentage points (all others are in Euros)
    is_pp_outcome <- outcome %in% c("post_working", "post_mentalhealth")
    selected_legend_title <- if (is_pp_outcome) "CATE\n(pp)" else "CATE\n(Euros)"

    # Load the data for the dataset of THIS iteration (`data_choice`), not for
    # the whole `selected_datasets` selection.
    estim_file <- sprintf(
      "cf_output_%s_%s_%s.dta",
      data_choice, selected_num_features, outcome
    )
    estim_dat <- read_dta(file.path(data_folder, estim_file))

    # Remove all attributes from STATA data so every column is a plain vector
    estim_dat[] <- lapply(estim_dat, function(x) {
      attributes(x) <- NULL
      x
    })

    # Make predicted effects into percentage points where relevant
    if (is_pp_outcome) {
      estim_dat$pred_eff <- estim_dat$pred_eff * 100
    }

    # Create a dataframe with average predicted treatment effects for each
    # education-income quartile pairing. Every combination of observed levels
    # gets a row; a combination without observations has mean NaN.
    average_pred_eff_df <- expand.grid(
      educ0 = sort(unique(estim_dat$educ0)),
      pre_avg_income_qrt = sort(unique(estim_dat$pre_avg_income_qrt)),
      KEEP.OUT.ATTRS = FALSE,
      stringsAsFactors = FALSE
    )
    average_pred_eff_df$avg_pred_eff <- vapply(
      seq_len(nrow(average_pred_eff_df)),
      function(i) {
        in_cell <- estim_dat$educ0 == average_pred_eff_df$educ0[i] &
          estim_dat$pre_avg_income_qrt == average_pred_eff_df$pre_avg_income_qrt[i]
        mean(estim_dat$pred_eff[in_cell])
      },
      numeric(1)
    )

    # Make sure they are recognized as factors (required for good alignment)
    average_pred_eff_df$educ0 <- factor(average_pred_eff_df$educ0, levels = 3:8)
    average_pred_eff_df$pre_avg_income_qrt <- factor(
      average_pred_eff_df$pre_avg_income_qrt,
      levels = 1:4
    )

    ## Heatmap
    # Maternal education level & income quartile
    heatmap_educ_inc <- ggplot(
      average_pred_eff_df,
      aes(x = educ0, y = pre_avg_income_qrt, fill = avg_pred_eff)
    ) +
      geom_tile() +
      # Income quartile scale ordered so that higher values mean higher income
      # and are positioned intuitively on the plot; label each axis level and
      # remove the padding around the tiles (expand = c(0,0)).
      scale_y_discrete(name = "Income Quartile (pre-shock)",
                       limits = factor(1:4),
                       breaks = c(4,3,2,1),
                       labels = c("Q4 (highest)",
                                  "Q3",
                                  "Q2",
                                  "Q1 (lowest)"),
                       expand = c(0,0)) +
      scale_x_discrete(name = "Education",
                       limits = factor(3:8),
                       breaks = c(3,4,5,6,7,8),
                       labels = c("Upper secondary",
                                  "Post-secondary non-tertiary",
                                  "Short-cycle tertiary",
                                  "Bachelor's or equivalent",
                                  "Master's or equivalent",
                                  "Doctoral or equivalent"),
                       expand = c(0,0)) +
      theme(axis.text.x = element_text(angle = -90, hjust = 0),
            axis.ticks = element_blank()) +
      # Fill colors and plot style
      # Color palette: "grayC" (greyscale)
      scico::scale_fill_scico(name = selected_legend_title,
                              palette = "grayC",
                              direction = selected_direction,
                              begin = 0.05)

    # Auto-printing is disabled inside for loops, so a bare `heatmap_educ_inc`
    # displays nothing; print explicitly when someone is watching.
    if (interactive()) {
      print(heatmap_educ_inc)
    }

    ggsave(
      filename = file.path(
        output_cate,
        sprintf("heatmap_%s_educ_income.pdf", outcome)
      ),
      plot = heatmap_educ_inc,
      device = "pdf",
      width = 6.5,
      height = 7
    )
  }
}

# End of file